from functools import partial
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import numpy as np
import jax
import jax.numpy as jnp

from diffgro.utils import llm, print_r, print_b 


def calculate_grad(
    guide_fn: Callable,
    x_0_hat: jax.Array,
    obs_dim: int,
    **kwargs,
):
    """Evaluate a guidance loss and its gradient with respect to ``x_0_hat``.

    Args:
        guide_fn: Scalar loss called as ``guide_fn(x_0_hat, obs_dim, **kwargs)``,
            or ``None`` to disable guidance.
        x_0_hat: Array the loss is differentiated against (first argument only).
        obs_dim: Forwarded unchanged as the second positional argument.
        **kwargs: Extra keyword arguments forwarded to ``guide_fn``.

    Returns:
        ``(grad, info)`` where ``grad`` is the gradient with respect to
        ``x_0_hat`` and ``info`` is ``{'loss': loss_value}``. When ``guide_fn``
        is ``None`` the result is ``(None, {'loss': 0.0})``.
    """
    if guide_fn is not None:
        loss_and_grad = jax.value_and_grad(guide_fn, argnums=(0,), has_aux=False)
        loss_value, grads = loss_and_grad(x_0_hat, obs_dim, **kwargs)
        # argnums is a 1-tuple, so the gradients come back as a 1-tuple as well.
        (grad_x,) = grads
        return grad_x, {'loss': loss_value}

    # No guidance requested: nothing to differentiate.
    return None, {'loss': 0.0}

# ====================================================================================== #
# manually designed functions

@partial(jax.jit, static_argnums=(1,))
def _loss_abc(
    x_0_hat: jax.Array, # intermediate diffusion output
    obs_dim: int,   # observation dimension
) -> jax.Array:
    """Experimental hinge penalty on the total absolute action magnitude.

    The action slice is ``x_0_hat[:, :, obs_dim:-1]`` (the trailing feature is
    excluded). ``x_0_hat`` is indexed with three subscripts; presumably it is
    laid out as (batch, horizon, features) -- TODO confirm with the caller.

    Args:
        x_0_hat: Intermediate diffusion output.
        obs_dim: Number of leading observation features (static under jit).

    Returns:
        Scalar ``max(0, sum(|act|) - 0.32)``.
    """
    act = x_0_hat[:, :, obs_dim:-1]
    # NOTE(review): the sum runs over every element of the action slice, so the
    # 0.32 threshold applies to the whole tensor, not to individual steps.
    energy = jnp.sum(jnp.abs(act))
    # `energy` is already a scalar, so no further reduction is needed.
    return jnp.maximum(0.0, energy - 0.32)

@partial(jax.jit, static_argnums=(1,))
def _loss_imitation(
    x_0_hat: jax.Array, # intermediate diffusion output
    obs_dim: int,   # observation dimension
    y: jax.Array,   # goal trajectory
) -> jax.Array:
    """Mean absolute error between the diffusion output and a goal trajectory.

    Args:
        x_0_hat: Intermediate diffusion output.
        obs_dim: Unused here; kept so the function can be called with the same
            ``(x_0_hat, obs_dim, **kwargs)`` convention as the other losses.
        y: Goal trajectory; must broadcast against ``x_0_hat``.

    Returns:
        Scalar ``mean(|x_0_hat - y|)``.
    """
    return jnp.mean(jnp.abs(x_0_hat - y))

@partial(jax.jit, static_argnums=(1,))
def _loss_slower(
    x_0_hat: jax.Array, # intermediate diffusion output
    obs_dim: int,   # observation dimension
) -> jax.Array:
    """Mean squared action; minimising it shrinks every action component.

    The action slice is ``x_0_hat[:, :, obs_dim:-1]`` (the trailing feature is
    excluded).

    Args:
        x_0_hat: Intermediate diffusion output, indexed as a rank-3 array.
        obs_dim: Number of leading observation features (static under jit).

    Returns:
        Scalar ``mean(act ** 2)``.
    """
    act = x_0_hat[:, :, obs_dim:-1]
    return jnp.mean(jnp.square(act))

@partial(jax.jit, static_argnums=(1,))
def _loss_faster(
    x_0_hat: jax.Array, # intermediate diffusion output
    obs_dim: int,   # observation dimension
) -> jax.Array:
    """Negative mean squared action; minimising it grows the action magnitude.

    The action slice is ``x_0_hat[:, :, obs_dim:-1]`` (the trailing feature is
    excluded).

    Args:
        x_0_hat: Intermediate diffusion output, indexed as a rank-3 array.
        obs_dim: Number of leading observation features (static under jit).

    Returns:
        Scalar ``-mean(act ** 2)``.
    """
    act = x_0_hat[:, :, obs_dim:-1]
    return -jnp.mean(jnp.square(act))

@partial(jax.jit, static_argnums=(1,))
def _loss_x_slower(
    x_0_hat: jax.Array, # intermediate diffusion output
    obs_dim: int,   # observation dimension
) -> jax.Array:
    """Mean squared value of action component 0.

    The action slice is ``x_0_hat[:, :, obs_dim:-1]`` (the trailing feature is
    excluded); component 0 is presumably the x axis, going by the function name.

    Args:
        x_0_hat: Intermediate diffusion output, indexed as a rank-3 array.
        obs_dim: Number of leading observation features (static under jit).

    Returns:
        Scalar ``mean(act[:, :, 0] ** 2)``.
    """
    act = x_0_hat[:, :, obs_dim:-1]
    return jnp.mean(jnp.square(act[:, :, 0]))

@partial(jax.jit, static_argnums=(1,))
def _loss_x_faster(
    x_0_hat: jax.Array, # intermediate diffusion output
    obs_dim: int,   # observation dimension
) -> jax.Array:
    """Negative mean squared value of action component 0.

    The action slice is ``x_0_hat[:, :, obs_dim:-1]`` (the trailing feature is
    excluded); component 0 is presumably the x axis, going by the function name.

    Args:
        x_0_hat: Intermediate diffusion output, indexed as a rank-3 array.
        obs_dim: Number of leading observation features (static under jit).

    Returns:
        Scalar ``-mean(act[:, :, 0] ** 2)``.
    """
    act = x_0_hat[:, :, obs_dim:-1]
    return -jnp.mean(jnp.square(act[:, :, 0]))

@partial(jax.jit, static_argnums=(1,))
def _loss_y_slower(
    x_0_hat: jax.Array, # intermediate diffusion output
    obs_dim: int,   # observation dimension
) -> jax.Array:
    """Mean squared value of action component 1.

    The action slice is ``x_0_hat[:, :, obs_dim:-1]`` (the trailing feature is
    excluded); component 1 is presumably the y axis, going by the function name.

    Args:
        x_0_hat: Intermediate diffusion output, indexed as a rank-3 array.
        obs_dim: Number of leading observation features (static under jit).

    Returns:
        Scalar ``mean(act[:, :, 1] ** 2)``.
    """
    act = x_0_hat[:, :, obs_dim:-1]
    return jnp.mean(jnp.square(act[:, :, 1]))

@partial(jax.jit, static_argnums=(1,))
def _loss_y_faster(
    x_0_hat: jax.Array, # intermediate diffusion output
    obs_dim: int,   # observation dimension
) -> jax.Array:
    """Negative mean squared value of action component 1.

    The action slice is ``x_0_hat[:, :, obs_dim:-1]`` (the trailing feature is
    excluded); component 1 is presumably the y axis, going by the function name.

    Args:
        x_0_hat: Intermediate diffusion output, indexed as a rank-3 array.
        obs_dim: Number of leading observation features (static under jit).

    Returns:
        Scalar ``-mean(act[:, :, 1] ** 2)``.
    """
    act = x_0_hat[:, :, obs_dim:-1]
    return -jnp.mean(jnp.square(act[:, :, 1]))

@partial(jax.jit, static_argnums=(1,))
def _loss_z_slower(
    x_0_hat: jax.Array, # intermediate diffusion output
    obs_dim: int,   # observation dimension
) -> jax.Array:
    """Mean squared value of action component 2.

    The action slice is ``x_0_hat[:, :, obs_dim:-1]`` (the trailing feature is
    excluded); component 2 is presumably the z axis, going by the function name.

    Args:
        x_0_hat: Intermediate diffusion output, indexed as a rank-3 array.
        obs_dim: Number of leading observation features (static under jit).

    Returns:
        Scalar ``mean(act[:, :, 2] ** 2)``.
    """
    act = x_0_hat[:, :, obs_dim:-1]
    return jnp.mean(jnp.square(act[:, :, 2]))

@partial(jax.jit, static_argnums=(1,))
def _loss_limit(
    x_0_hat: jax.Array, # intermediate diffusion output
    obs_dim: int,   # observation dimension
) -> jax.Array:
    """Penalise per-step force magnitudes above 0.25.

    The force at each step is the L2 norm of the first three action components,
    where the action slice is ``x_0_hat[:, :, obs_dim:-1]`` (the trailing
    feature is excluded). Only the part of each norm above 0.25 contributes.

    Args:
        x_0_hat: Intermediate diffusion output, indexed as a rank-3 array.
        obs_dim: Number of leading observation features (static under jit).

    Returns:
        Scalar sum over all steps of ``max(0, ||act[..., :3]|| - 0.25)``.
    """
    act = x_0_hat[:, :, obs_dim:-1]
    # axis=-1 yields one norm per step. Without it the norm collapses the whole
    # tensor into a single scalar, so the limit would apply to the entire batch
    # at once and the final sum would be a no-op.
    force = jnp.linalg.norm(act[:, :, :3], axis=-1)
    excess_force = jnp.maximum(force - 0.25, 0.0)
    return jnp.sum(excess_force)

@partial(jax.jit, static_argnums=(1,))
def _loss_energy(
    x_0_hat: jax.Array, # intermediate diffusion output
    obs_dim: int,   # observation dimension
) -> jax.Array:
    """Penalise step-to-step changes in the action magnitude.

    The action slice is ``x_0_hat[:, :, obs_dim:-1]`` (the trailing feature is
    excluded). For each step the L2 norm of the action is taken, consecutive
    norms along axis 1 are differenced, and the L2 norm of that difference
    vector is averaged over the leading axis.

    Args:
        x_0_hat: Intermediate diffusion output, indexed as a rank-3 array.
        obs_dim: Number of leading observation features (static under jit).

    Returns:
        Scalar mean of the per-trajectory norm of consecutive magnitude changes.
    """
    act = x_0_hat[:, :, obs_dim:-1]
    magnitude = jnp.linalg.norm(act, axis=-1)
    # Difference of neighbouring steps. The previous code had a stray comma
    # here, which passed the negated slice as `ord` instead of subtracting it.
    delta = magnitude[:, :-1] - magnitude[:, 1:]
    energy = jnp.linalg.norm(delta, axis=-1)
    return jnp.mean(energy)

# ====================================================================================== #
# LLM-generated loss functions

def _loss_llm(prompt: str) -> Tuple[Callable, str]:
    """Ask the LLM for a ``_loss_fn`` definition, compile it, and return it.

    The response is cut from ``def _loss_fn`` up to and including the first
    ``return loss`` that follows it, wrapped in ``jax.jit`` with ``obs_dim``
    (argument 1) static, and executed in this module's globals.

    Args:
        prompt: Prompt forwarded to ``llm.chatgpt``.

    Returns:
        ``(loss_fn, code)``: the compiled function and the exact source that
        was executed (including the jit decorator line).

    Raises:
        ValueError: If the response lacks ``def _loss_fn`` or a subsequent
            ``return loss``.
    """
    response = llm.chatgpt(prompt)
    idx1 = response.find('def _loss_fn')
    if idx1 < 0:
        raise ValueError("LLM response does not contain 'def _loss_fn'")
    # Search only after the definition starts, so a 'return loss' mentioned in
    # any text preceding the function cannot truncate the extracted code.
    return_at = response.find('return loss', idx1)
    if return_at < 0:
        raise ValueError("LLM response does not contain 'return loss' after 'def _loss_fn'")
    idx2 = return_at + len('return loss')
    code = response[idx1:idx2]
    print_b(f'Generated Code at {idx1}-{idx2}:\n {code}')
    code = '@partial(jax.jit, static_argnums=(1,))\n' + code  # add jitting
    # SECURITY: this executes model-generated source with full module
    # privileges. Only use with trusted prompts/models or inside a sandbox.
    exec(code, globals())
    return globals()['_loss_fn'], code

def _loss_txt(code: str) -> Tuple[Callable, str]:
    """Compile a ``_loss_fn`` definition found inside a block of text.

    The text is cut from ``def _loss_fn`` through the end of the first
    following line that contains ``return`` (or to the end of the text when no
    such line exists). Any ``x[:,:,obs_dim:]`` slice is rewritten to
    ``x[:,:,obs_dim:-1]`` so the trailing feature is excluded, the function is
    wrapped in ``jax.jit`` with ``obs_dim`` (argument 1) static, and the result
    is executed in this module's globals.

    Args:
        code: Text containing a ``def _loss_fn`` definition.

    Returns:
        ``(loss_fn, code)``: the compiled function and the exact source that
        was executed (including the jit decorator line).

    Raises:
        ValueError: If ``code`` does not contain ``def _loss_fn``.
    """
    idx1 = code.find('def _loss_fn')
    if idx1 < 0:
        raise ValueError("text does not contain 'def _loss_fn'")
    # Scan line by line starting at the definition. Starting at the top of the
    # text would stop on a 'return' that appears before the function and leave
    # nothing to compile.
    idx2 = idx1
    for line in code[idx1:].split('\n'):
        idx2 += len(line) + 1  # +1 accounts for the newline removed by split
        if 'return' in line:
            break
    code = code[idx1:idx2]
    print_b(f'Generated Code at {idx1}-{idx2}:\n {code}')
    if "x[:,:,obs_dim:]" in code:
        print_r("REPLACE The Code !!")
        code = code.replace("x[:,:,obs_dim:]", "x[:,:,obs_dim:-1]")
    code = '@partial(jax.jit, static_argnums=(1,))\n' + code  # add jitting
    # SECURITY: this executes the supplied source with full module privileges.
    # Only pass text from a trusted origin.
    exec(code, globals())
    return globals()['_loss_fn'], code

# ====================================================================================== #

def _manual_loss_fn(context_type, context_target):
    """Build a jitted ``_loss_fn(x, obs_dim)`` from a hand-written template.

    Args:
        context_type: One of 'speed below', 'speed above', 'x-axis faster',
            'y-axis faster', 'speed above and below', 'faster', 'slower'.
        context_target: Threshold interpolated into the generated source. For
            'speed above and below' it is indexed: element 0 is the lower
            bound and element 1 the upper bound. Ignored by 'faster'/'slower'.

    Returns:
        ``(loss_fn, code)``: the compiled function and the source executed.

    Raises:
        NotImplementedError: If ``context_type`` is not recognised.
    """
    # Which part of the action slice the "speed" norm is taken over.
    if context_type == 'x-axis faster':
        speed_src = "act[:,:,0:1]"
    elif context_type == 'y-axis faster':
        speed_src = "act[:,:,1:2]"
    else:
        speed_src = "act"

    # The statement that turns `speed` into the scalar loss.
    if context_type == 'speed below':
        loss_stmt = f"loss = jnp.mean(jnp.maximum(0.0, speed - {context_target}))"
    elif context_type in ('speed above', 'x-axis faster', 'y-axis faster'):
        loss_stmt = f"loss = jnp.mean(jnp.maximum(0.0, {context_target} - speed))"
    elif context_type == 'speed above and below':
        # below + above
        loss_stmt = f"loss = jnp.mean(jnp.maximum(0.0, speed - {context_target[1]})) +  jnp.mean(jnp.maximum(0.0, {context_target[0]} - speed))"
    elif context_type == 'faster':
        loss_stmt = "loss=-jnp.mean(speed)"
    elif context_type == 'slower':
        loss_stmt = "loss=jnp.mean(speed)"
    else:
        raise NotImplementedError

    # Function header followed by tab-indented body lines.
    source_lines = [
        "def _loss_fn(x, obs_dim):",
        "act = x[:,:,obs_dim:-1]",
        f"speed = jnp.linalg.norm({speed_src}, axis=-1)",
        loss_stmt,
        "return loss",
    ]
    code = '@partial(jax.jit, static_argnums=(1,))\n' + "\n\t".join(source_lines)
    # Defines `_loss_fn` in this module's globals.
    exec(code, globals())
    return globals()['_loss_fn'], code


# Registry of guidance losses, keyed by guide kind:
#   "test"   -> name -> ready-made jitted loss function
#   "blank"  -> None (no guidance)
#   "manual" -> factory building a loss from (context_type, context_target)
#   "llm"    -> factory building a loss from an LLM prompt
#   "txt"    -> factory building a loss from source text
guide_fn_dict = dict(
    test=dict(
        abc=_loss_abc,
        imitation=_loss_imitation,
        slower=_loss_slower,
        faster=_loss_faster,
        x_slower=_loss_x_slower,
        x_faster=_loss_x_faster,
        y_slower=_loss_y_slower,
        y_faster=_loss_y_faster,
        z_slower=_loss_z_slower,
        limit=_loss_limit,
        energy=_loss_energy,
    ),
    blank=None,
    manual=_manual_loss_fn,
    llm=_loss_llm,
    txt=_loss_txt,
)
